{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model Inversion attack\n",
    "The past few years have seen an explosion of machine learning models to applied to all areas of industry.\n",
    "But just how much thought is given to security and safety of a model and the data upon which it's trained?\n",
    "There are a myriad of ways that a model and its data can be compromised.\n",
    "In this tutorial,\n",
    "we will be running a simple **model inversion attack** on a classifier trained on MNIST.\n",
    "\n",
    "## What is model inversion?\n",
    "The aim of a model inversion attack is to recreate some data which was input to the model.\n",
    "If such an attack is successful,\n",
    "the consequences for a model or data holder could be disastrous,\n",
    "for their reputation\n",
    "and the possible legal liability of failing to adequately protect data.\n",
    "\n",
    "## How to we do it?\n",
    "Let's say there is an externally hosted,\n",
    "trained model we wish to attack (the **target model**).\n",
    "As we don't have a copy of the model,\n",
    "we don't have any knowledge of the model parameters which we can utilise.\n",
    "The best we can do is query the model by sending data to it to be classified.\n",
    "This type of attack setting is known as **black box**,\n",
    "because we know nothing about what's going on inside the model.\n",
    "\n",
    "What we want to do is to train our own model\n",
    "(the **attack model**)\n",
    "which takes some output from the target model\n",
    "and recreates the data which was fed to it.\n",
    "Our attack model is essentially performing the same role as a _decoder_ in an autoencoder,\n",
    "except in our case the _encoder_\n",
    "(the target model)\n",
    "has already been trained.\n",
    "\n",
    "As we can query the target with our own data,\n",
    "we can easily create a dataset of (target input, target output) data on which to train our attack model.\n",
    "But what should this data look like?\n",
    "If our target input data is too dissimilar to the data on which the target has been trained (for example, images of dogs applied to a model which classifies human faces),\n",
    "it is likely that the data will be encoded into some small,\n",
    "obscure part of the target's output space.\n",
    "When we run the attack on data more typical of the target's intended task,\n",
    "our attack model will not have learned how to turn output relating to a human face back into the human face.\n",
    "Therefore,\n",
    "we want our attacker training dataset to be _as close to the target training data_ as possible,\n",
    "however we do not need to use actual training data.\n",
    "In practice,\n",
    "we can use our knowledge about the task a model has been trained to complete to find data which may look like its training data.\n",
    "\n",
    "In this tutorial,\n",
    "let's consider the case of a model which is distributed across multiple parties:\n",
    "the first half of the model is located on one device;\n",
    "this model part performs some inference on the data\n",
    "and sends its output to the second device,\n",
    "where the inference is finished.\n",
    "This paradigm is increasingly common for models which run on mobile devices,\n",
    "as output from a model layer can be smaller than the input data and therefore easier to send to a central server;\n",
    "you may also recognise this process from [SplitNN](https://github.com/OpenMined/PySyft/tree/master/examples/tutorials/advanced/split_neural_network),\n",
    "a technique which preserves some data privacy\n",
    "as raw data does not need to be sent to an untrusted party.\n",
    "\n",
    "You can read more about the black box model inversion attack\n",
    "[here](https://ieeexplore.ieee.org/abstract/document/8835269).\n",
    "\n",
    "## Tutorial\n",
    "In this tutorial we:\n",
    "* Train a simple convolutional neural network (CNN), split into two parts, to classify images of handwritten digits (the MNIST dataset)\n",
    "* Train an attack model to invert output of the first part of the model back into images of handwritten digits\n",
    "* Run our attack model on unseen images\n",
    "\n",
    "Authors:\n",
    "* Tom Titcombe - Github: [@TTitcombe](https://github.com/TTitcombe)"
   ]
  },
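  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of the black-box setting described above,\n",
    "all the attacker needs is the ability to query the target;\n",
    "here `query_target` is a hypothetical stand-in for a remote inference API,\n",
    "not part of this tutorial's code\n",
    "(later, we will intercept an intermediate model stage instead):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical sketch: build an attacker training set of\n",
    "# (target input, target output) pairs by querying a black-box model.\n",
    "# `query_target` stands in for a remote inference API.\n",
    "def build_attack_dataset(query_target, candidate_inputs):\n",
    "    return [(x, query_target(x)) for x in candidate_inputs]"
   ]
  },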
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import namedtuple\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import torchvision.transforms as transforms\n",
    "from torchvision.datasets import EMNIST, MNIST\n",
    "from tqdm.notebook import tqdm, trange"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Define hyperparameters for training the target model and the attack model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hyperparams = namedtuple('hyperparams', 'batch_size,epochs,learning_rate,n_data')\n",
    "\n",
    "# Target model hyperparameters\n",
    "target_hyperparams = hyperparams(\n",
    "    batch_size=256,\n",
    "    epochs=10,\n",
    "    learning_rate=1e-4,\n",
    "    n_data=20_000,  # We don't need all the training data to get a decent MNIST classifier\n",
    ")\n",
    "\n",
    "# Attack model hyperparameters\n",
    "attacker_hyperparams = hyperparams(\n",
    "    batch_size=32,\n",
    "    epochs=10,\n",
    "    learning_rate=1e-4,\n",
    "    n_data=500,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mnist_transform = transforms.Compose(\n",
    "    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)),]\n",
    ")\n",
    "\n",
    "# Target model training data\n",
    "train_data = MNIST(\"mnist\", train=True, download=True, transform=mnist_transform)\n",
    "\n",
    "# We don't need to use all the training data for MNIST as it's a simple dataset\n",
    "train_data.data = train_data.data[:target_hyperparams.n_data]\n",
    "train_data.targets = train_data.targets[:target_hyperparams.n_data]\n",
    "\n",
    "# Target model test data\n",
    "test_data = MNIST(\"mnist\", train=False, download=True, transform=mnist_transform)\n",
    "\n",
    "# Create data loaders\n",
    "train_loader = torch.utils.data.DataLoader(train_data, batch_size=target_hyperparams.batch_size)\n",
    "test_loader = torch.utils.data.DataLoader(test_data, batch_size=1_000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "# Step 1: Create a basic classifier\n",
    "MNIST is not a complex dataset, so we do not need a large model. We create a single model for the purposes of this tutorial, but split the computation into two separate stages. In practice, the two stages would be hosted on different devices and the output of one stage would be communicated to the second device."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Classifier(torch.nn.Module):\n",
    "    def __init__(self, first_network, second_network) -> None:\n",
    "        super().__init__()\n",
    "\n",
    "        # --- First stage --- #\n",
    "        self.stage1 = first_network\n",
    "        \n",
    "        # --- Second stage --- #\n",
    "        # In practice, at this point the output of the previous stage would be transmitted to\n",
    "        # a central server, where inference would continue\n",
    "        self.stage2 = second_network\n",
    "\n",
    "    def mobile_stage(self, x):\n",
    "        return self.stage1(x)\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = self.mobile_stage(x)\n",
    "        out = out.view(out.size(0), -1)\n",
    "\n",
    "        return self.stage2(out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The first part of the network. This would be hosted on a mobile device\n",
    "first_network = torch.nn.Sequential(\n",
    "                torch.nn.Conv2d(1, 32, kernel_size=5, padding=0, stride=1),  # first Conv layer\n",
    "                torch.nn.ReLU(),\n",
    "                torch.nn.MaxPool2d(kernel_size=2),\n",
    "                torch.nn.Conv2d(32, 32, kernel_size=5, padding=0, stride=1),  # second Conv layer\n",
    "                torch.nn.ReLU(),\n",
    "                torch.nn.MaxPool2d(kernel_size=2),\n",
    "        )\n",
    "\n",
    "\n",
    "# The second and final part of the network. This would be typically hosted on a central server in practice\n",
    "second_network = torch.nn.Sequential(\n",
    "                torch.nn.Linear(512, 256),\n",
    "                torch.nn.ReLU(),\n",
    "                torch.nn.Linear(256, 10),  # 10-class output\n",
    "                torch.nn.Softmax(dim=-1),\n",
    "        )\n",
    "\n",
    "target_model = Classifier(first_network, second_network)\n",
    "target_model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train target model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "optim = torch.optim.Adam(target_model.parameters(), lr=target_hyperparams.learning_rate)\n",
    "loss_criterion = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "for epoch in trange(target_hyperparams.epochs):\n",
    "    train_correct = 0\n",
    "    train_loss = 0.\n",
    "\n",
    "    # Training loop\n",
    "    for data, targets in train_loader:\n",
    "        optim.zero_grad()\n",
    "\n",
    "        output = target_model(data)\n",
    "\n",
    "        # Update network\n",
    "        loss = loss_criterion(output, targets)\n",
    "        loss.backward()\n",
    "        optim.step()\n",
    "\n",
    "        # Track training statistics\n",
    "        _, predicted = output.max(1)\n",
    "        train_correct += predicted.eq(targets).sum().item()\n",
    "        train_loss += loss.item()\n",
    "\n",
    "train_loss /= len(train_data)\n",
    "\n",
    "# Check test accuracy\n",
    "test_correct = 0\n",
    "test_loss = 0.\n",
    "\n",
    "for data, targets in test_loader:\n",
    "    with torch.no_grad():\n",
    "        output = target_model(data)\n",
    "\n",
    "    loss = loss_criterion(output, targets)\n",
    "\n",
    "    _, predicted = output.max(1)\n",
    "    test_correct += predicted.eq(targets).sum().item()\n",
    "    test_loss += loss.item()\n",
    "\n",
    "test_loss /= len(test_data)\n",
    "\n",
    "print(\n",
    "    f\"Training loss: {train_loss:.3f}\\n\"\n",
    "    f\"Test loss: {test_loss:.3f}\"\n",
    ")\n",
    "\n",
    "print(\n",
    "    f\"Training accuracy: {100 * train_correct / target_hyperparams.n_data:.3f}\\n\"\n",
    "    f\"Test accuracy: {100 * test_correct / len(test_data):.3f}\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "# Step 2: Train an attack model\n",
    "You will notice that our attack model,\n",
    "below,\n",
    "is made up a deconvolutional layers.\n",
    "This leverages our additional knowledge that the input data are images,\n",
    "however this attack can work with other types of layers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class AttackModel(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        self.layers = torch.nn.Sequential(\n",
    "                torch.nn.ConvTranspose2d(\n",
    "                    in_channels=32,\n",
    "                    out_channels=32,\n",
    "                    kernel_size=7,\n",
    "                    padding=1,\n",
    "                    stride=2,\n",
    "                    output_padding=1,\n",
    "                ),\n",
    "                torch.nn.ReLU(),\n",
    "                torch.nn.ConvTranspose2d(\n",
    "                    in_channels=32,\n",
    "                    out_channels=32,\n",
    "                    kernel_size=5,\n",
    "                    padding=1,\n",
    "                    stride=2,\n",
    "                    output_padding=1,\n",
    "                ),\n",
    "                torch.nn.ReLU(),\n",
    "                torch.nn.ConvTranspose2d(\n",
    "                    in_channels=32, out_channels=1, kernel_size=5, padding=1, stride=1,\n",
    "                ),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.layers(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Collecting attacker training data\n",
    "We may not know _exactly_ what the training data looked like,\n",
    "but perhaps we find out that the engineers used a dataset of some kind of handwriting.\n",
    "With this knowledge,\n",
    "we can use the\n",
    "[`EMNIST`](https://www.nist.gov/itl/products-and-services/emnist-dataset)\n",
    "dataset of handwritten letters to create our attack dataset;\n",
    "hopefully the images in this dataset are similar enough to the target's training data\n",
    "that the attack will still be successful."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get attacker\n",
    "attacker = AttackModel()\n",
    "\n",
    "# Get dataset\n",
    "attacker_dataset = EMNIST(\"emnist\", \"letters\", download=True, train=False, transform=mnist_transform)\n",
    "\n",
    "# Use the last n_data images in the test set to train the attacker\n",
    "attacker_dataset.data = attacker_dataset.data[:attacker_hyperparams.n_data]\n",
    "attacker_dataset.targets = attacker_dataset.targets[:attacker_hyperparams.n_data]\n",
    "\n",
    "attacker_train_loader = torch.utils.data.DataLoader(attacker_dataset, batch_size=attacker_hyperparams.batch_size)\n",
    "\n",
    "# Train attacker\n",
    "attack_optim = torch.optim.Adam(attacker.parameters(), lr=attacker_hyperparams.learning_rate)\n",
    "\n",
    "for epoch in trange(attacker_hyperparams.epochs):\n",
    "    for data, targets in attacker_train_loader:\n",
    "        data.float()\n",
    "        targets.float()\n",
    "\n",
    "        attack_optim.zero_grad()\n",
    "\n",
    "        # We intercept the output of the mobile device's model\n",
    "        # This is the input of our attack model\n",
    "        with torch.no_grad():\n",
    "            attack_input = target_model.mobile_stage(data)\n",
    "\n",
    "        output = attacker(attack_input)\n",
    "\n",
    "        loss = ((output - data)**2).mean()  # We want our reconstructed image to look as much like the original image as possible\n",
    "        loss.backward()\n",
    "        attack_optim.step()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "# Step 3: Attack the target\n",
    "With our attack model trained,\n",
    "we can extract images from the target model.\n",
    "The next time somebody uses the target,\n",
    "we will intercept the output of the first model stage\n",
    "and hopefully be able to generate our victim's data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_images(\n",
    "    tensors,\n",
    "):\n",
    "    \"\"\"\n",
    "    Plot normalised MNIST tensors as images\n",
    "    \"\"\"\n",
    "    fig = plt.figure(figsize=(10, 5))\n",
    "\n",
    "    n_tensors = len(tensors)\n",
    "    n_cols = min(n_tensors, 4)\n",
    "    n_rows = int((n_tensors - 1) / 4) + 1\n",
    "\n",
    "    # De-normalise an MNIST tensor\n",
    "    mu = torch.tensor([0.1307], dtype=torch.float32)\n",
    "    sigma = torch.tensor([0.3081], dtype=torch.float32)\n",
    "    Unnormalise = transforms.Normalize((-mu / sigma).tolist(), (1.0 / sigma).tolist())\n",
    "\n",
    "    for row in range(n_rows):\n",
    "        for col in range(n_cols):\n",
    "            idx = n_cols * row + col\n",
    "\n",
    "            if idx > n_tensors - 1:\n",
    "                break\n",
    "\n",
    "            ax = fig.add_subplot(n_rows, n_cols, idx + 1)\n",
    "            tensor = Unnormalise(tensors[idx])\n",
    "\n",
    "            # Clip image values so we can plot\n",
    "            tensor[tensor < 0] = 0\n",
    "            tensor[tensor > 1] = 1\n",
    "\n",
    "            tensor = tensor.squeeze(0)  # remove batch dim\n",
    "\n",
    "            ax.imshow(transforms.ToPILImage()(tensor), interpolation=\"bicubic\")\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def attack(attack_model, target_model, dataset, idxs):\n",
    "    images = []\n",
    "\n",
    "    for datum_idx in idxs:\n",
    "        actual_image, _ = dataset[datum_idx]\n",
    "\n",
    "        with torch.no_grad():\n",
    "            target_output = target_model.mobile_stage(actual_image.unsqueeze(0))\n",
    "            reconstructed_image = attack_model(target_output).squeeze(0)\n",
    "\n",
    "        images.append(actual_image)\n",
    "        images.append(reconstructed_image)\n",
    "\n",
    "    plot_images(images)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "attack(attacker, target_model, test_data, range(6))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can see that we have successfully recreated the input images\n",
    "without having access to any training data.\n",
    "While handwritten digits are not the most private of data,\n",
    "this attack could also be applied\n",
    "to more sensitive data,\n",
    "such as face images\n",
    "or MRI records.\n",
    "\n",
    "## Extension\n",
    "* Try to invert output from layers deeper into the target model as well as the classification probabilities. What do you notice about an attacker's capacity to invert data?\n",
    "* Consider simple, practical steps that can be taken to protect a model from a model inversion attack. In particular, what could be done to the model, and to the way in which the model is exposed to the world? "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
